import torch
import torch.nn as nn

# Input: one single-channel 5x5 "image" holding the values 0..24 in
# row-major order, laid out as (batch, channels, height, width).
A = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
print(A)

# A 3x3 convolution with one input and one output channel. Stride 1,
# zero padding and dilation 1 are the Conv2d defaults, so they are left
# implicit. Weights and bias are randomly initialised by PyTorch.
f = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3))

# Without padding, each spatial dimension shrinks to
# (5 - dilation*(3 - 1) - 1) // 1 + 1 = 3, so the result is 1x1x3x3.
B = f(A)
print(B.size())
